import soapy
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
from astropy.io import fits
import copy

from Global_variable_setting import image_size, patch_size


# Custom exceptions raised by this script

class RepeatProcessing(Exception):
    """Raised when a result directory for the selected model already exists.

    The script aborts instead of re-running so that results written by an
    earlier evaluation are not overwritten.
    """

    def __init__(self):
        super().__init__("This model has been tested by this program. There is no need to test again. "
                         "And this program will exit for the sake of the existing result data!")


class NoDataError(Exception):
    """Raised when the chosen model directory has no entries to evaluate."""

    def __init__(self):
        message = 'There is no data for test in this directory!'
        Exception.__init__(self, message)


# preparing function for normalization
def normalize(data, *, frame_size=None, patch=None):
    """Min-max normalize each square patch of ``data`` to [0, 1], in place.

    The top-left ``frame_size`` x ``frame_size`` region of ``data`` is tiled
    into non-overlapping ``patch`` x ``patch`` squares; each square is
    rescaled independently as ``(x - min) / (max - min)``. Pixels beyond the
    last whole patch (when ``frame_size`` is not a multiple of ``patch``)
    are left untouched.

    Args:
        data: 2-D array, modified in place. Pass a floating-point array;
            writing fractional values into an integer array would truncate.
        frame_size: side length of the region to tile; defaults to the
            global ``image_size``.
        patch: side length of one patch; defaults to the global
            ``patch_size``.

    Returns:
        None (``data`` is updated in place).
    """
    if frame_size is None:
        frame_size = image_size
    if patch is None:
        patch = patch_size
    n_patches = frame_size // patch
    for i in range(n_patches):
        rows = slice(i * patch, (i + 1) * patch)
        for j in range(n_patches):
            cols = slice(j * patch, (j + 1) * patch)
            local_area = data[rows, cols]
            lo = np.min(local_area)
            hi = np.max(local_area)
            if hi == lo:
                # Constant patch (including all-zero): there is no range to
                # stretch, and dividing by zero would fill it with NaN.
                continue
            data[rows, cols] = (local_area - lo) / (hi - lo)


# Build the soapy simulation from the YAML configuration stored under the
# current working directory and run its set-up steps.
work_dir = os.getcwd()
config_file = f'{work_dir}/configuration_dir/generate_data.yaml'

sim = soapy.Sim(config_file)
sim.aoinit()  # initialise the simulation objects
sim.makeIMat()  # presumably builds the interaction matrix (per the name) — confirm


# Ask the user which trained model to evaluate, and derive both the model
# directory and the matching result directory from the answers.


def _ask_choice(prompt, choices):
    """Prompt repeatedly until the answer is one of ``choices``; return it."""
    while True:
        answer = input(prompt)
        if answer in choices:
            return answer


dataset_type = _ask_choice('Enter dataset type (normal or perfect):', ('normal', 'perfect'))
if dataset_type == 'perfect':
    model_path_inter = work_dir + '/result/training_result/good_data_for_test'
    result_path_inter = work_dir + '/result/soapy_test_result/good_data_for_test'
else:
    model_path_inter = work_dir + '/result/training_result'
    result_path_inter = work_dir + '/result/soapy_test_result'

model_type = _ask_choice('Enter model type (Vision Transformer or CNN):', ('Vision Transformer', 'CNN'))
model_path_inter = model_path_inter + '/' + model_type
result_path_inter = result_path_inter + '/' + model_type

normalized = _ask_choice('Enter if this model is normalized (True or False):', ('True', 'False'))
norm_dir = 'normalized' if normalized == 'True' else 'unnormalized'
model_path_inter = model_path_inter + '/' + norm_dir
result_path_inter = result_path_inter + '/' + norm_dir

# Model runs are assumed to live in sub-directories named 0 .. n-1.
n_models = len(os.listdir(model_path_inter))
if n_models == 0:
    raise NoDataError
while True:
    raw_index = input(f'Enter the index of model ( in the range of [0, {n_models})):')
    try:
        index = int(raw_index)
    except ValueError:
        # Non-numeric input used to crash the script; ask again instead.
        continue
    if 0 <= index < n_models:
        # Use the canonical form so inputs like ' 3' or '03' still map to
        # the directory named '3'.
        model_index = str(index)
        break
model_path = model_path_inter + '/' + model_index + '/model'
result_path = result_path_inter + '/' + model_index


# Create a fresh result directory for this model. If it already exists, the
# model was evaluated before: abort rather than overwrite those results.
if os.path.exists(result_path):
    raise RepeatProcessing
os.makedirs(result_path)


# Load the trained model (a TensorFlow SavedModel directory).
model = tf.saved_model.load(model_path)
# model = tf.keras.models.load_model(model_path)


def _read_fits(path, dtype):
    """Return the primary-HDU data of the FITS file at ``path`` cast to ``dtype``.

    The file is opened in a ``with`` block so its handle is released
    immediately; ``astype`` copies the data before the file is closed.
    """
    with fits.open(path) as hdul:
        return hdul[0].data.astype(dtype)


def _evaluate_samples(test_dir, label, n_samples=1000):
    """Run the model on every sample below ``test_dir`` and record the Strehl ratio.

    Each sub-directory of ``test_dir`` must contain sample directories named
    ``0`` .. ``n_samples - 1``. Each sample directory holds
    ``atmos_scrns.fits`` and a detector frame, which is
    ``distorted_detector.fits`` or ``perfect_detector.fits`` depending on the
    global ``dataset_type``.

    Returns:
        dict mapping each sub-directory name to the list of
        ``sim.sciCams[0].instStrehl`` values measured after applying the
        model's DM command.
    """
    detector_file = 'distorted_detector.fits' if dataset_type == 'normal' else 'perfect_detector.fits'
    results = {}
    for folder in os.listdir(test_dir):
        results[folder] = []
        for i in range(n_samples):
            element_dir = test_dir + '/' + folder + '/' + str(i)
            scrns = _read_fits(element_dir + '/atmos_scrns.fits', 'float64')
            sh_frame = _read_fits(element_dir + '/' + detector_file, 'float32')
            if normalized == 'True':
                normalize(sh_frame)
            # Add a leading batch axis and a trailing channel axis, presumably
            # the model's (batch, H, W, channel) input layout — confirm.
            sh_frame = sh_frame[np.newaxis, ..., np.newaxis]
            dm_command = model(sh_frame)[0]
            # dm_command = np.ones(289)
            dm_shape = sim.dms[0].dmFrame(dm_command)[np.newaxis, ...]
            sim.sciCams[0].frame(scrns, dm_shape)
            inst_strehl = copy.deepcopy(sim.sciCams[0].instStrehl)

            results[folder].append(inst_strehl)

            print(f'Processing {label} data ({folder}): {i}, the strehl result is {inst_strehl}')
    return results


# Evaluate performance across star magnitudes first, then across r0.
test_diff_mag_dir = work_dir + '/testing_data/different_mag'
diff_mag = _evaluate_samples(test_diff_mag_dir, 'diff_mag')

test_diff_r0_dir = work_dir + '/testing_data/different_r0'
diff_r0 = _evaluate_samples(test_diff_r0_dir, 'diff_r0')

# Summarise the per-sample Strehl ratios: mean and standard deviation for each
# test condition, in the same order as the x-axis values used for plotting.
diff_mag_values = [8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]
diff_r0_values = [0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.20]

# Result-dict keys are the folder names, e.g. 'star_mag=8' and 'r0=0.1'.
diff_mag_groups = [diff_mag['star_mag=' + str(mag)] for mag in diff_mag_values]
diff_mag_instStrehl_mean = [np.mean(group) for group in diff_mag_groups]
diff_mag_instStrehl_std = [np.std(group) for group in diff_mag_groups]

diff_r0_groups = [diff_r0['r0=' + str(r0)] for r0 in diff_r0_values]
diff_r0_instStrehl_mean = [np.mean(group) for group in diff_r0_groups]
diff_r0_instStrehl_std = [np.std(group) for group in diff_r0_groups]


# Plot mean +/- std of the instantaneous Strehl ratio per test condition and
# save each figure into the result directory.


def _plot_strehl(x_values, means, stds, xlabel, title, file_name):
    """Draw one error-bar figure, save it as ``result_path + '/' + file_name``, then show it."""
    plt.figure()
    plt.errorbar(x_values, means, yerr=stds)
    plt.xlabel(xlabel)
    plt.ylabel('inst Strehl (mean+-std)')
    plt.title(title)
    plt.tight_layout()
    plt.savefig(result_path + '/' + file_name)
    plt.show()


_plot_strehl(diff_mag_values, diff_mag_instStrehl_mean, diff_mag_instStrehl_std,
             'star magnitude', 'inst strehl vs diff star magnitude', 'diff_mag_vs_inststrehl')
# The r0 values are 0.05 - 0.20; read as centimetres that would be
# sub-millimetre seeing, so they are presumably metres — confirm against the
# data-generation config.
_plot_strehl(diff_r0_values, diff_r0_instStrehl_mean, diff_r0_instStrehl_std,
             'atmosphere r0 (m)', 'inst strehl vs diff atmosphere r0', 'diff_r0_vs_inststrehl')










# def generate_single_test():
#     sim.atmos.randomScrns()
#     for _ in range(3):
#         plt.subplot(1, 3, _+1)
#         plt.imshow(sim.atmos.scrns[_], origin='lower')
#         plt.colorbar()
#         plt.title('scrn ' +str(_))
#     plt.tight_layout()
#     plt.show()
#
#     plt.imshow(sim.wfss[0].los.frame(sim.atmos.scrns) * sim.mask, origin='lower')
#     cbar = plt.colorbar()
#     plt.title('pupil phase with radius')
#     plt.show()
#
#     perfect_slopes = sim.wfss[0].frame(sim.atmos.scrns)
#     perfect_acts = sim.recon.control_matrix.T.dot(perfect_slopes)
#
#     plt.imshow(sim.dms[0].dmFrame(perfect_acts) * sim.mask * sim.wfss[0].los.phs2Rad, origin='lower',
#                vmax=cbar.vmax, vmin=cbar.vmin)
#     plt.title('dm screen with radius')
#     plt.colorbar()
#     plt.show()
#
#     dm_correction = sim.dms[0].dmFrame(perfect_acts)[np.newaxis, ...]
#     plt.imshow(sim.wfss[0].los.frame(sim.atmos.scrns, dm_correction) * sim.wfss[0].los.phs2Rad * sim.mask,
#                origin='lower', vmax=cbar.vmax, vmin=cbar.vmin)
#     plt.colorbar()
#     plt.title('phase after correction with radius')
#     plt.show()
#
#     plt.imshow(sim.sciCams[0].frame(sim.atmos.scrns), origin='lower')
#     plt.colorbar()
#     plt.title('image before correction')
#     plt.show()
#
#     plt.imshow(sim.sciCams[0].frame(sim.atmos.scrns, dm_correction), origin='lower')
#     plt.colorbar()
#     plt.title('image after correction')
#     plt.show()
#
#
# loop_times = 10
# for _ in range(loop_times):
#     generate_single_test()
